package com.googlecode.gaal.vis;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

import com.googlecode.gaal.data.api.SymbolTable;
import com.googlecode.gaal.suffix.api.BinaryIntervalTree.BinaryNode;
import com.googlecode.gaal.suffix.api.IntervalTree;
import com.googlecode.gaal.suffix.api.IntervalTree.Interval;
import com.googlecode.gaal.suffix.api.SuffixTree.Node;
import com.googlecode.gaal.vis.api.Drawing;
import com.googlecode.gaal.vis.impl.TikzConstants;

/**
 * Draws an interval tree onto a {@link Drawing}, emitting one drawn node per
 * visited tree node and one edge per parent/child pair.
 *
 * <p>A node is placed at {@code x = left + size / 2} and {@code y = -depth}.
 * Depending on the {@code intervalLabels} flag either the nodes are labelled
 * with their interval (and edges carry an empty label) or the edges are
 * labelled with the symbols returned by the {@link SymbolTable} (and nodes
 * carry an empty label).
 *
 * <p>Instances keep per-run state in fields, so a single instance must not be
 * used for two visualizations at the same time.
 *
 * @param <S> symbol type understood by the symbol table
 */
public class TreeVisualizer<S> {

    /** Sentinel value of {@code maxDepth} meaning "no depth limit". */
    private static final int UNLIMITED_DEPTH = -1;

    private final SymbolTable<S> symbolTable;
    private final boolean intervalLabels;
    private int maxDepth;
    private Drawing drawing;
    // Both maps start out empty so that accessors never hand out null before
    // the first visualization; each run replaces them with fresh instances.
    private Map<Interval, Integer> maxDepthNodes = new HashMap<Interval, Integer>();
    private Map<Interval, Integer> styleMap = new HashMap<Interval, Integer>();

    /**
     * Creates a visualizer.
     *
     * @param symbolTable    used to render edge labels when
     *                       {@code intervalLabels} is {@code false}
     * @param intervalLabels {@code true} to label nodes with their interval and
     *                       leave edges unlabelled; {@code false} to label edges
     *                       and leave nodes unlabelled
     */
    public TreeVisualizer(SymbolTable<S> symbolTable, boolean intervalLabels) {
        this.symbolTable = symbolTable;
        this.intervalLabels = intervalLabels;
    }

    /**
     * Returns the style indices assigned during the most recent
     * {@code visualizeTree} call, keyed by the node that introduced the style.
     *
     * @return the style map of the last run; an empty map if no tree has been
     *         visualized yet
     */
    public Map<Interval, Integer> getStyleMap() {
        return styleMap;
    }

    /**
     * Draws the whole tree without a depth limit.
     *
     * @param drawing target drawing
     * @param tree    tree whose {@code top()} node is the root of the drawing
     * @return the map of nodes recorded at the depth limit; always empty for
     *         this overload because no limit applies
     */
    public <E extends Interval> Map<Interval, Integer> visualizeTree(Drawing drawing, IntervalTree<E> tree) {
        return visualizeTree(drawing, tree, UNLIMITED_DEPTH);
    }

    /**
     * Draws the tree down to {@code maxDepth} levels below the root.
     *
     * @param drawing  target drawing
     * @param tree     tree whose {@code top()} node is the root of the drawing
     * @param maxDepth deepest level to draw (the root is level 0), or
     *                 {@code -1} for no limit
     * @return a map from every node drawn at exactly {@code maxDepth} to the
     *         number the drawing assigned to it
     */
    public <E extends Interval> Map<Interval, Integer> visualizeTree(Drawing drawing, IntervalTree<E> tree, int maxDepth) {
        this.drawing = drawing;
        this.maxDepth = maxDepth;
        // Fresh maps per run: maps returned by earlier runs stay untouched.
        maxDepthNodes = new HashMap<Interval, Integer>();
        styleMap = new HashMap<Interval, Integer>();
        visualizeNode(tree.top(), null, -1, 0, 0);
        return maxDepthNodes;
    }

    /**
     * Draws {@code node}, the edge to its parent (if any) and, depth limit
     * permitting, recurses into its children.
     *
     * @param node         node to draw
     * @param parent       parent of {@code node}, or {@code null} for the root
     * @param parentNumber number the drawing assigned to the parent; unused for
     *                     the root
     * @param parentStyle  style index of the parent; also selects the style of
     *                     the edge leading to {@code node}
     * @param depth        level of {@code node}, the root being level 0
     */
    private void visualizeNode(Interval node, Interval parent, int parentNumber, int parentStyle, int depth) {
        boolean terminal = node.isTerminal();

        int x = node.left() + node.size() / 2;
        int y = -depth;

        // A node inherits its parent's style unless it is an internal node
        // whose lcp is strictly larger than the parent's; such a node gets the
        // next unused style index and is remembered in the style map.
        int style = parentStyle;
        if (parent != null && !terminal && node.lcp() > parent.lcp()) {
            style = styleMap.size() + 1;
            styleMap.put(node, style);
        }

        String nodeLabel = intervalLabels ? getIntervalLabel(node) : "";

        // NOTE(review): style indices grow with the number of styled nodes and
        // are not bounds-checked here; presumably TikzConstants.NODE_STYLES and
        // EDGE_STYLES are large enough for the trees drawn - confirm.
        int number = drawing.drawNode(x, y, nodeLabel, TikzConstants.NODE_STYLES[style], terminal);

        if (parent != null) {
            String edgeLabel = intervalLabels ? "" : symbolTable.toString(node.edgeLabel(parent), "");
            drawing.drawEdge(number, parentNumber, edgeLabel, TikzConstants.EDGE_STYLES[parentStyle]);
        }

        boolean mayDescend = (maxDepth == UNLIMITED_DEPTH) || (depth < maxDepth);
        if (!terminal && mayDescend) {
            // Non-terminal nodes that are neither a BinaryNode nor a Node are
            // drawn but their children are not visited.
            if (node instanceof BinaryNode) {
                @SuppressWarnings("unchecked")
                BinaryNode<? extends Interval> binaryInterval = (BinaryNode<? extends Interval>) node;
                visualizeNode(binaryInterval.leftChild(), node, number, style, depth + 1);
                visualizeNode(binaryInterval.rightChild(), node, number, style, depth + 1);
            } else if (node instanceof Node) {
                Node n = (Node) node;
                Iterator<Node> childIterator = n.childIterator();
                while (childIterator.hasNext()) {
                    visualizeNode(childIterator.next(), node, number, style, depth + 1);
                }
            }
        } else if (depth == maxDepth) {
            // Every node sitting exactly on the depth limit (terminal or not)
            // is reported back to the caller together with its drawn number.
            maxDepthNodes.put(node, number);
        }
    }

    /**
     * Builds the textual interval label of a node.
     *
     * @param node node to label
     * @return the left bound for a terminal node, otherwise
     *         {@code "left..right"}
     */
    private String getIntervalLabel(Interval node) {
        if (node.isTerminal()) {
            return Integer.toString(node.left());
        } else {
            return String.format("%d..%d", node.left(), node.right());
        }
    }
}
